#! /usr/bin/python
# this script generates an ensemble of SAXS profiles for random tetrahedra, double disks and circles
# (each decorated with an N-terminal core), and compares them to experimental data
# launch: ./generator_with_Nter.py modloop17_7_cter.pdb model_Nter_minimized.pdb conc1.dat saxs

import biobox as bb

from scipy import interpolate
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.cross_validation import cross_val_score
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import sys, os
from random import random
from copy import deepcopy

### helper functions ###

def vdw_energy(m1, m2, epsilon=1.0, sigma=2.7, cutoff=12.0):
    """Return a 9-6 Lennard-Jones-like interaction energy between two atom sets.

    Every pair (a, b), with a taken from ``m1`` and b from ``m2``, whose
    Euclidean distance d is strictly below ``cutoff`` contributes
    ``4*epsilon*((sigma/d)**9 - (sigma/d)**6)``; pairs at or beyond the cutoff
    are ignored. Callers in this script treat a negative total as clash-free.

    :param m1: coordinates of the first atom set, assumed shape (N, 3)
    :param m2: coordinates of the second atom set, assumed shape (M, 3)
    :param epsilon: energy scale (default 1.0, the previously hard-coded value)
    :param sigma: distance scale, same units as the coordinates (default 2.7)
    :param cutoff: pairs at distance >= cutoff are skipped (default 12.0)
    :returns: the summed energy as a float (0.0 when no pair is within cutoff)
    """
    from scipy.spatial.distance import cdist

    a = np.atleast_2d(np.asarray(m1, dtype=float))
    b = np.atleast_2d(np.asarray(m2, dtype=float))
    if a.size == 0 or b.size == 0:
        return 0.0

    # full (N, M) distance matrix, computed in C instead of row by row
    dist = cdist(a, b)

    # only pairs closer than the cutoff contribute
    close = dist[dist < cutoff]
    ratio = sigma / close
    return float(np.sum(4 * epsilon * (ratio ** 9 - ratio ** 6)))


# SAXS signal smoothing function
def convolve(signal, degree):
    """Smooth a 1-D signal with a normalized, truncated Gaussian-shaped kernel.

    The kernel has ``window = 2*degree + 1`` taps; tap i gets weight
    ``exp(-(4*(i - degree)/window)**2)`` before being normalized to unit sum.
    numpy's 'same' mode is used, so the output length is
    ``max(len(signal), window)`` and the edges are implicitly zero-padded.

    :param signal: 1-D sequence of values to smooth
    :param degree: half-width of the kernel in samples (0 leaves values unchanged)
    :returns: numpy array with the smoothed signal
    """
    window = degree * 2 + 1

    # position of each tap relative to the centre, as a fraction of the window
    frac = (np.arange(window) - degree) / float(window)
    kernel = np.exp(-(4 * frac) ** 2)
    kernel /= np.sum(kernel)

    return np.convolve(signal, kernel, 'same')


#prepare model dataset. Align with experiment sampling points
def prepare_dataset(saxs_1, saxs_axis, first, last, model_axis=None):
    """Resample model SAXS curves onto the experimental q axis, normalize and log them.

    Each curve is linearly interpolated from ``model_axis`` onto ``saxs_axis``,
    divided by its value at index ``first``, sliced to ``[first:last]`` and
    log-transformed.

    :param saxs_1: model intensities, one curve per row (a single 1-D curve,
        e.g. from ``np.loadtxt`` of a one-row file, is also accepted)
    :param saxs_axis: q values at which the experimental data were sampled
    :param first: index used for normalization and start of the slice
    :param last: end index (exclusive) of the slice
    :param model_axis: q values at which the model curves were computed.
        Defaults to the module-level ``axis_1`` for backward compatibility;
        pass the matching axis when preparing datasets computed on another grid.
    :returns: 2-D array of log-normalized curves, shape (n_curves, last - first)
    """
    if model_axis is None:
        # previous behaviour: always use the tetrahedra q axis from module scope
        model_axis = axis_1

    rows = []
    for curve in np.atleast_2d(saxs_1):
        interpolator = interpolate.interp1d(model_axis, curve)
        resampled = interpolator(saxs_axis)
        resampled = resampled / resampled[first]
        rows.append(resampled[first:last])

    return np.log(np.array(rows))

#######################################################################################

#parameters
if len(sys.argv) < 5:
    sys.exit("usage: %s building_block.pdb Nter_building_block.pdb reference_file pdb|saxs" % sys.argv[0])

building_block = sys.argv[1] #dimer PDB file
N_ter_building_block = sys.argv[2] #NTER dimer
filename = sys.argv[3] #exp. signal (SAXS data file, or reference PDB when ref_type is "pdb")

#flag indicating whether the exp. signal has
# to be calculated from a reference PDB, or extracted from a file
ref_type = sys.argv[4]
if ref_type not in ("pdb", "saxs"):
    # sys.exit with a message prints it to stderr and exits with status 1
    sys.exit("ERROR: expected reference either pdb or saxs")

amount = 1000 #quantity of models per topology
training = 800 #amount of models used for training (must be smaller than amount)
maxsaxs = 0.3 #maximal q value to consider for analysis
smoothing = 1 #smoothing of experimental signal
n_estimators = 50 # number of trees to generate

# the models beyond index `training` form the test set, which must not be empty
if not 0 < training < amount:
    sys.exit("ERROR: training (%s) must be positive and smaller than amount (%s)" % (training, amount))

stoichiometry = 6 # units loaded into the double disk and circle assemblies

M = bb.Molecule()
M.import_pdb(building_block)
M.align_axes()

#NTER model building data structures
M0 = bb.Molecule()
M0.import_pdb(N_ter_building_block)
Nter = bb.Multimer()
Nter.setup_polyhedron("Tetrahedron", M0)
Nter.generate_polyhedron(10, 0, 0, 0)
# first unit vs. all remaining units of the tetrahedral N-ter core
labels1n = list(Nter.unit_labels.keys())[0]
labels2n = list(Nter.unit_labels.keys())[1:len(Nter.unit)]

classes = ["tetrahedron", "double disk", "circle"]


#######################################################################################

#generate tetrahedra (with Nter)
try:
	print "> loading tetrahedra saxs and crds files..."
	saxs_1 = np.loadtxt("saxs_tetrahedra_Nter.dat")
	crds_1 = np.loadtxt("crds_tetrahedra_Nter.dat")
	axis_1 = np.loadtxt("axis_tetrahedra_Nter.dat")

except:
	Ptest = bb.Multimer()
	Ptest.setup_polyhedron("Tetrahedron", M)

	#extract the indices of atoms of interest is a single individual
	atoms=["CA"] # consider clashes with CA and CB atoms
	Ptest.generate_polyhedron(10, 0, 0, 0)
	labels1 = Ptest.unit_labels.keys()[0]
	labels2 = Ptest.unit_labels.keys()[1:len(Ptest.unit)]

	print "\n> creating %s energetically favorable tetrahedra"%amount

	saxs_val = []
        crds_val = []
	i = 0
	while i < amount:

		print "> attempting tetrahedron nb. %s"%i

		v1 = 30+random()*30
		v2 = 90+random()*180
		v3 = 90+random()*180
		v4 = 90+random()*180
		Ptest.generate_polyhedron(v1, v2, v3, v4)

		#extract atoms for clash detection, and compute energy
		m1 = Ptest.atomselect(labels1,"*","*",atoms)
		m2 = Ptest.atomselect(labels2,"*","*",atoms)
		energy = vdw_energy(m1,m2)

		#compute CCS if energy is negative (clash otherwise)
		if energy < 0:

			#build Nter model
			success = False
			k = 0
			while k < 50:
				v1n = 0+random()*v1
				v2n = 90+random()*180
				v3n = 90+random()*180
				v4n = 90+random()*180
				Nter.generate_polyhedron(v1n, v2n, v3n, v4n)

				m1n = Nter.atomselect(labels1n, "*", "*", atoms)
				m2n = Nter.atomselect(labels2n, "*", "*", atoms)
				
				#check clashes within N-ter core and between core and assembly
				if vdw_energy(m1n, m2n)<0:
					k += 1
					if vdw_energy(m2, m2n) < 0:
						print ">> added a Nter core at %s th attempt!"%k
						success = True
						break


			if not success:
				print ">> aww, failed because of Nter core..."
				continue

			S = Nter.make_molecule()
			Ptest.append(S)

			crds_val.append([v1, v2, v3, v4, v1n, v2n, v3n, v4n])
			s = Ptest.saxs()
			saxs_val.append(s[:, 1])
			
			i += 1

		else:
			print ">> aww, clash detected..."

	saxs_1 = np.array(saxs_val)
	crds_1 = np.array(crds_val)
	axis_1 = s[:,0]

	np.savetxt("saxs_tetrahedra_Nter.dat", saxs_1)
	np.savetxt("crds_tetrahedra_Nter.dat", crds_1)
	np.savetxt("axis_tetrahedra_Nter.dat", axis_1)


#########################################################################

#generate double disks (with Nter)
try:
	print "> loading double disk saxs and crds files..."
	saxs_2 = np.loadtxt("saxs_%s_ddisk_Nter.dat"%stoichiometry)
	crds_2 = np.loadtxt("crds_%s_ddisk_Nter.dat"%stoichiometry)
	axis_2 = np.loadtxt("axis_%s_ddisk_Nter.dat"%stoichiometry)
except:

	Ptest = bb.Multimer()
	Ptest.load(M,stoichiometry)
	Ptest.make_prism(10,20,0,0,0,t=0)

	#extract the indices of atoms of interest is a single individual
	atoms = ["CA","CB"] # consider clashes with CA and CB atoms
	labels1 = Ptest.unit_labels.keys()[0]
	labels2 = Ptest.unit_labels.keys()[1:len(Ptest.unit)]
	print "\n> creating %s energetically favorable double disks"%amount

        dim = np.max(M.get_xyz(),axis=0)-np.min(M.get_xyz(),axis=0)
        maxsize = dim[0]/(2*np.tan(np.pi/(stoichiometry/2.0)))
        minsize = dim[2]/(2*np.tan(np.pi/(stoichiometry/2.0)))

        print minsize, maxsize,dim[2], dim[0]

	saxs_val = []
        crds_val = []
	i = 0
	while i < amount:

		print "> attempting double disk nb. %s"%i
		v1 = minsize+random()*maxsize
		v2 = dim[2]+random()*dim[0]
		v3 = 90+random()*180
		v4 = 90+random()*180
		v5 = 90+random()*180

		Ptest = bb.Multimer()
		Ptest.load(M,stoichiometry)
		Ptest.make_prism(v1,v2,v3,v4,v5,t=0)

		#extract atoms for clash detection, and compute energy
		m1 = Ptest.atomselect(labels1, "*", "*", atoms)
		m2 = Ptest.atomselect(labels2, "*", "*", atoms)
		energy = vdw_energy(m1, m2)

		#compute SAXS if energy is negative (clash otherwise)
		if energy < 0:

			#build Nter model
			success=False
			k=0
		        print ">> energy = %s for in ddisk %s %s %s %s %s!"%(energy,v1,v2,v3,v4,v5)
                        print ">> attempting core addition..."
			while k<100:
                                k += 1
				v1n = 0+random()*v1
				v2n = 0+random()*v2
				v3n = 90+random()*180
				v4n = 90+random()*180
				v5n = 90+random()*180

				Nter = bb.Multimer()
				Nter.load(M0,stoichiometry)
				Nter.make_prism(v1n, v2n, v3n, v4n, v5n, t=0)

				m1n = Nter.atomselect(labels1n, "*", "*", atoms)
				m2n = Nter.atomselect(labels2n, "*", "*", atoms)
 
				energy_2 = vdw_energy(m1n, m2n)
                                print ">>> try nb. %s, [%s, %s, %s, %s, %s] -> %s"%(k, v1n, v2n, v3n, v4n, v5n, energy_2)

				#check clashes within N-ter core and between core and assembly
				if energy_2 < 0:
                                        energy_3 = vdw_energy(m2, m2n)
					if energy_3 < 0:
						print ">> added a Nter core at %s th attempt!"%k
						success = True
						break

			if not success:
				print ">> aww, failed because of Nter core..."
				continue

			S = Nter.make_molecule()
			Ptest.append(S)

			crds_val.append([v1,v2,v3,v4,v5,v1n,v2n,v3n,v4n,v5n])
			s = Ptest.saxs()
			saxs_val.append(s[:,1])
			
			i+=1

		else:
			print ">> aww, clash detected..."

	saxs_2 = np.array(saxs_val)
	crds_2 = np.array(crds_val)
	axis_2 = s[:,0]

	np.savetxt("saxs_%s_ddisk_Nter.dat"%stoichiometry,saxs_2)
	np.savetxt("crds_%s_ddisk_Nter.dat"%stoichiometry,crds_2)
	np.savetxt("axis_%s_ddisk_Nter.dat"%stoichiometry,axis_2)


#######################################################################################

#generate circles
try:
	print "> loading double disk saxs and crds files..."
	saxs_3=np.loadtxt("saxs_%s_circle_Nter.dat"%stoichiometry)
	crds_3=np.loadtxt("crds_%s_circle_Nter.dat"%stoichiometry)
	axis_3=np.loadtxt("axis_%s_circle_Nter.dat"%stoichiometry)
except:

	Ptest = bb.Multimer()
	Ptest.load(M,stoichiometry)
	Ptest.rotate(0, 0, 0)
	Ptest.make_circular_symmetry(10)

	#extract the indices of atoms of interest is a single individual
	atoms = ["CA","CB"] # consider clashes with CA and CB atoms
	labels1 = Ptest.unit_labels.keys()[0]
	labels2 = Ptest.unit_labels.keys()[1:len(Ptest.unit)]
	print "\n> creating %s energetically favorable circles"%amount

        dim = np.max(M.get_xyz(), axis=0)-np.min(M.get_xyz(), axis=0)
        maxsize = dim[0]/(2*np.tan(np.pi/stoichiometry))-dim[2]/2.0
        minsize = dim[2]/(2*np.tan(np.pi/stoichiometry))-dim[2]/2.0-stoichiometry

	saxs_val = []
        crds_val = []
	i = 0
	while i < amount:

		print "> attempting double disk nb. %s"%i
		v1 = minsize+random()*maxsize
		v2 = 90+random()*180
		v3 = 90+random()*180
		v4 = 90+random()*180

		Ptest = bb.Multimer()
		Ptest.load(M, stoichiometry)
		Ptest.rotate(v2, v3, v4)
		Ptest.make_circular_symmetry(v1)

		#extract atoms for clash detection, and compute energy
		m1 = Ptest.atomselect(labels1, "*", "*", atoms)
		m2 = Ptest.atomselect(labels2, "*", "*", atoms)
		energy = vdw_energy(m1, m2)

		#compute SAXS if energy is negative (clash otherwise)
		if energy<0:

			#build Nter model
			success=False
			k=0
		        print ">> energy = %s for in ddisk %s %s %s %s!"%(energy,v1,v2,v3,v4)
                        print ">> attempting core addition..."
			while k < 100:
                                k += 1
				v1n = 0+random()*v1
				v2n = 90+random()*180
				v3n = 90+random()*180
				v4n = 90+random()*180

				Nter = bb.Multimer()
				Nter.load(M0,stoichiometry)
				Nter.rotate(v2n,v3n,v4n)
				Nter.make_circular_symmetry(v1n)

				m1n = Nter.atomselect(labels1n,"*","*",atoms)
				m2n = Nter.atomselect(labels2n,"*","*",atoms)
 
				energy_2 = vdw_energy(m1n, m2n)
                                print ">>> try nb. %s, [%s, %s, %s, %s] -> %s"%(k, v1n, v2n, v3n, v4n, energy_2)

				#check clashes within N-ter core and between core and assembly
				if energy_2<0:
                                        energy_3 = vdw_energy(m2, m2n)
					if energy_3 < 0:
						print ">> added a Nter core at %s th attempt!"%k
						success = True
						break

                                     
			if not success:
				print ">> aww, failed because of Nter core..."
				continue

			S = Nter.make_molecule()
			Ptest.append(S)

			crds_val.append([v1, v2, v3, v4, v1n, v2n, v3n, v4n])
			s = Ptest.saxs()
			saxs_val.append(s[:,1])
			
			i += 1

		else:
			print ">> aww, clash detected..."

	saxs_3 = np.array(saxs_val)
	crds_3 = np.array(crds_val)
	axis_3 = s[:, 0]

	np.savetxt("saxs_%s_circle_Nter.dat"%stoichiometry, saxs_3)
	np.savetxt("crds_%s_circle_Nter.dat"%stoichiometry, crds_3)
	np.savetxt("axis_%s_circle_Nter.dat"%stoichiometry, axis_3)


#######################################################################################
#######################################################################################

print "> loading x-ray and computing SAXS data"
M = bb.Molecule()

if ref_type == "pdb":
    M.import_pdb(filename)
    saxs_data_raw = M.saxs(crysol_options="-lm 50 -ns 2000")

else:
    #load experimental values
    saxs_data_raw = np.loadtxt(filename)

saxs_axis = deepcopy(saxs_data_raw[:,0])

#smooth, normalize, log scale and save experimental data
saxs_smooth = convolve(deepcopy(saxs_data_raw[:, 1]), smoothing)
first = np.argmax(saxs_smooth) #detect position of max in exp data
last = np.argmin(np.abs(saxs_axis-maxsaxs))
saxs_data = np.log(saxs_smooth[first:last]/saxs_smooth[first])
np.savetxt("saxs_experiment.dat", saxs_data)
print "> experimental dataset has %s measuring points"%len(saxs_smooth[first:last])

#prepare model datasets (slice, rebin according to experiment's axis, and log)
s1 = prepare_dataset(saxs_1, saxs_axis, first, last) #tetrahedron
s2 = prepare_dataset(saxs_2, saxs_axis, first, last) #double disk
s3 = prepare_dataset(saxs_3, saxs_axis, first, last) #circle

print "\n> cross correlation coefficients"
c = []
for i in xrange(0, len(s1), 1):
    c.append(np.corrcoef(s1[i], saxs_data)[0, 1])

cross1 = np.array(c)
print ">> tetrahedron. mean: %s, std: %s, max: %s, min: %s"%(np.mean(cross1), np.std(cross1), np.max(cross1), np.min(cross1))

c = []
for i in xrange(0, len(s2), 1):
    c.append(np.corrcoef(s2[i], saxs_data)[0, 1])

cross2 = np.array(c)
print ">> double disk. mean: %s, std: %s, max: %s, min: %s"%(np.mean(cross2), np.std(cross2), np.max(cross2), np.min(cross2))

c = []
for i in xrange(0, len(s3), 1):
    c.append(np.corrcoef(s3[i],saxs_data)[0, 1])

cross3 = np.array(c)
print ">> circle. mean: %s, std: %s, max: %s, min: %s\n"%(np.mean(cross3), np.std(cross3), np.max(cross3), np.min(cross3))

#create training and test sets (three classes)
train_set = np.concatenate((s1[0:training], s2[0:training], s3[0:training]))
train_class = np.concatenate((np.ones(training), np.ones(training)*2, np.ones(training)*3))
test_set = np.concatenate((s1[training:len(s1)], s2[training:len(s1)], s3[training:len(s1)]))
test_class = np.concatenate((np.ones(len(s1)-training), np.ones(len(s2)-training)*2, np.ones(len(s3)-training)*3))

#train random forest
B = ExtraTreesClassifier(n_estimators=n_estimators)
B.fit(train_set,train_class)

#determine features importance
importances = B.feature_importances_
std = np.std([tree.feature_importances_ for tree in B.estimators_],
             axis=0)

#test random forest
r = B.predict(test_set)
success = np.sum((r-test_class)==0)
print "> %s successful predictions on %s tests, i.e. %s percent"%(success,len(test_class),float(success)/len(test_class)*100)

print "> the following mistakes occurred:"
wrong = np.where((r-test_class)!=0)[0]
for k in wrong:
    print ">> expected: %i, predicted %i"%(test_class[k],r[k])


#predict architecture of experimental data
r = B.predict([saxs_data])
print "> prediction on experimental data: %s"%(classes[int(r[0])-1])

#######################################################################################

#4C6EBA: blue-ish, #79BA4C: green-ish, #FADA66: yellow-ish

#plot SAXS signals: for each topology one opaque curve (used in the legend)
#plus faint copies of all remaining models, then the experimental points
mydpi = 120
fig = plt.figure(figsize=(800./mydpi, 1000./mydpi), dpi=mydpi)
ax1 = fig.add_subplot(2, 1, 1)
q = saxs_axis[first:last]

handles = []
for curves, colour in ((s2, "#79BA4C"), (s3, "#FADA66"), (s1, '#4C6EBA')):
    # plot() returns a list of Line2D objects; legend() needs the artist itself
    handles.append(ax1.plot(q, curves[0, :], color=colour, alpha=0.9)[0])
    for curve in curves[1:]:
        ax1.plot(q, curve, color=colour, alpha=0.1)

handles.append(ax1.plot(q, saxs_data, 'k.', linewidth=1.5)[0])

leg = ax1.legend(handles, ["double disks", "circles", "tetrahedra", "SAXS"], loc='upper right')
leg.get_frame().set_alpha(0)

ax1.set_ylabel('log(I(Q))')
ax1.set_ylim(np.min(train_set), np.max(train_set))
ax1.get_xaxis().set_visible(False)

# Plot the feature importances of the forest
ax2 = fig.add_subplot(2, 1, 2, sharex=ax1)
importance_line = ax2.plot(q, importances, 'k')[0]

# keep the original 0.03 ceiling unless some importance would be clipped
ax2.set_ylim(0, max(0.03, 1.1 * np.max(importances)))
ax1.set_xlim([0, maxsaxs])

leg = ax2.legend([importance_line], ["feature importance"], loc='upper right')
leg.get_frame().set_alpha(0)

fig.subplots_adjust(hspace=0)
ax2.set_xlabel('Q(A^-1)')

plt.show()


